import cdtools
from cdtools.tools import plotting as p
from scipy.io import loadmat, savemat
from helper_functions import fourier_pad
import numpy as np
import torch as t
import argparse
from matplotlib import pyplot as plt

# Run on the GPU when one is available; otherwise fall back to the CPU so the
# script still runs (slowly) on machines without CUDA instead of crashing
# when the model is moved to the device.
device = 'cuda' if t.cuda.is_available() else 'cpu'

# NOTE(review): to_pad is not used anywhere in this script (nor is the
# imported fourier_pad helper) — probably left over from an earlier version.
to_pad = 200

# Diffraction data to reconstruct; each frame is reconstructed independently.
# Alternative dataset used previously: 'data/single_shot_rpi.cxi'
rpi_filename = 'data/magnetic_rpi_lcp.cxi'

dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(rpi_filename,
                                                    cut_zeros=True)

# Inspect the loaded dataset (cdtools' built-in viewer) as a sanity check
# before starting the reconstructions.
dataset.inspect()

# We load the probe and background from a ptycho calibration
init_file = 'results/magnetic_ptycho_lcp_synthesized.mat'
ptycho_results = loadmat(init_file)
probe = t.as_tensor(ptycho_results['probe_full'])
background = t.as_tensor(ptycho_results['background_full'])

# Object resolution for the RPI model. The scalar is embedded in the output
# filename, then rebound to the [rows, cols] list form that
# RPI.from_dataset is given below.
resolution = 210
savefile = f'results/single_shot_rpi_lcp_{resolution}.mat'

resolution = [resolution, resolution]

# L2 regularization factors for the two incoherent object modes. The
# regularizer empirically helps accelerate convergence; the much larger factor
# on the second mode drives the reconstruction into the first (top) mode,
# which is the one saved as the result below.
regularization_factor = [30, 300]

n_frames = len(dataset)
if n_frames == 0:
    # Without at least one frame no model is ever built, and there is
    # nothing to stack or save.
    raise ValueError(f'No frames found in {rpi_filename}; nothing to reconstruct.')

# The data retrieval device does not depend on the per-frame model, so it is
# configured once before the loop rather than on every frame.
dataset.get_as(device=device)

objs = []
for frame in range(n_frames):
    print('Frame idx', frame, 'of', n_frames)

    # We create a fresh RPI model from the dataset for each frame.
    # Note that we explicitly use two incoherent object modes.
    model = cdtools.models.RPI.from_dataset(
        dataset,
        probe,
        resolution,
        background=background, n_modes=2,
        initialization='uniform',
        weight_matrix=False,
        mask=dataset.mask)
    model.to(device=device)

    # A few iterations of Adam initialize the reconstruction of this frame
    # (subset=[frame] restricts the fit to this single pattern).
    for loss in model.Adam_optimize(50, dataset, subset=[frame], lr=0.04,
                                    regularization_factor=regularization_factor):
        print(model.report())

    # And we converge to the final result with L-BFGS, stopping early once
    # the reported loss is identical between consecutive iterations
    # (the optimizer has stalled).
    last_loss = None
    for loss in model.LBFGS_optimize(25, dataset, subset=[frame], lr=0.4,
                                     regularization_factor=regularization_factor,
                                     line_search_fn='strong_wolfe'):
        print(model.report())
        if loss == last_loss:
            break
        last_loss = loss

    # Keep only obj[0], presumably the top mode favored by the
    # regularization (assumes obj is indexed by mode first — TODO confirm).
    objs.append(model.obj[0].detach().cpu().numpy())

# obj_basis is taken from the last frame's model.
# NOTE(review): assumes it is identical for every frame, since all models are
# built from the same dataset and resolution — confirm.
basis = model.obj_basis.detach().cpu().numpy()
objs = np.stack(objs)

savemat(savefile, {'objs': objs, 'basis' : basis})

